import torch
import torch.nn as nn
import math

class LoRALayer(nn.Module):
    """Low-rank adapter (LoRA) that computes ``scaling * (x @ A @ B)``.

    ``A`` has shape ``(in_features, rank)`` and ``B`` has shape
    ``(rank, out_features)``, so the layer holds
    ``rank * (in_features + out_features)`` trainable parameters. This class
    holds no base weight. In the usual LoRA setup the caller adds its output
    to the output of a frozen base layer (not shown here).

    Initialisation follows the LoRA paper. ``A`` is drawn uniformly from
    ``[-1/sqrt(in_features), 1/sqrt(in_features)]``, the same bound that
    ``nn.Linear`` uses for its weight. ``B`` starts at zero, so the adapter
    contributes exactly nothing until training updates ``B``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation; must be >= 1.
        alpha: Optional LoRA scaling numerator. When given, the output is
            multiplied by ``alpha / rank``. When ``None`` (the default), no
            scaling is applied (factor 1.0).

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=None):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be >= 1, got {rank}")
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        # alpha / rank is the standard LoRA scaling. None keeps a factor of 1.0.
        self.scaling = 1.0 if alpha is None else alpha / rank

        # Down-projection (in_features -> rank) and up-projection (rank -> out_features).
        self.A = nn.Parameter(torch.empty(in_features, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_features))

        # Why not kaiming_uniform_(A, a=sqrt(5)): for a 2-D tensor it takes
        # fan_in from dim 1, which is `rank` in this transposed layout. The
        # bound below is what that call yields with fan_in = in_features,
        # i.e. nn.Linear's default weight init.
        bound = 1.0 / math.sqrt(in_features)
        nn.init.uniform_(self.A, -bound, bound)
        # B stays zero so the adapter's output is exactly zero at init.

    def forward(self, x):
        """Apply the low-rank update to ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features``. Leading
                dimensions are carried through by ``torch.matmul``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``out_features``, equal to ``scaling * (x @ A @ B)``.
        """
        return (x @ self.A) @ self.B * self.scaling

    def extra_repr(self):
        """Summarise the layer's configuration for ``repr()``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"rank={self.rank}, scaling={self.scaling}"
        )